{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import logging\n",
    "import numpy as np\n",
    "import random\n",
    "from tqdm.autonotebook import tqdm\n",
    "\n",
    "import mindspore\n",
    "from mindspore import nn\n",
    "from mindspore import Parameter, Tensor\n",
    "from mindspore.nn import AdamWeightDecay as AdamW\n",
    "import mindspore.dataset as ds\n",
    "from mindspore.dataset import GeneratorDataset, transforms\n",
    "from mindspore.dataset.text import Vocab as msVocab\n",
    "\n",
    "from mindnlp.modules import CRF\n",
    "from mindnlp.transformers import AutoModel, AutoTokenizer\n",
    "from mindnlp.engine import Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def seed_everything(seed):\n",
    "    random.seed(seed)\n",
    "    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    mindspore.set_seed(seed)\n",
    "    mindspore.dataset.config.set_seed(seed)\n",
    "\n",
    "# 读取文本，返回词典，索引表，句子，标签\n",
    "def read_data(path):\n",
    "    sentences = []\n",
    "    labels = []\n",
    "    with open(path, 'r', encoding='utf-8') as f:\n",
    "        sent = []\n",
    "        label = []\n",
    "        for line in f:\n",
    "            parts = line.split()\n",
    "            if len(parts) == 0:\n",
    "                if len(sent) != 0:\n",
    "                    sentences.append(sent)\n",
    "                    labels.append(label)\n",
    "                sent = []\n",
    "                label = []\n",
    "            else:\n",
    "                sent.append(parts[0])\n",
    "                label.append(parts[-1])\n",
    "                \n",
    "    return (sentences, labels)\n",
    "\n",
    "def read_vocab(path):\n",
    "    vocab_list = []\n",
    "    with open(path, 'r', encoding='utf-8') as f:\n",
    "        for word in f:\n",
    "            vocab_list.append(word.strip())\n",
    "    return vocab_list\n",
    "\n",
    "def get_entity(decode):\n",
    "    starting=False\n",
    "    p_ans=[]\n",
    "    for i,label in enumerate(decode):\n",
    "        if label > 0:\n",
    "            if label%2==1:\n",
    "                starting=True\n",
    "                p_ans.append(([i],labels_text_mp[label//2]))\n",
    "            elif starting:\n",
    "                p_ans[-1][0].append(i)\n",
    "        else:\n",
    "            starting=False\n",
    "    return p_ans\n",
    "\n",
    "# 处理数据\n",
    "class Feature(object):\n",
    "    def __init__(self, sent, label):\n",
    "        self.sent = sent\n",
    "        label = [LABEL_MAP[c] for c in label]\n",
    "        self.token_ids = list(tokenizer(' '.join(sent)).input_ids)\n",
    "        self.seq_length = len(self.token_ids) if len(self.token_ids) - 2 < max_Len else max_Len + 2\n",
    "        offset = tokenizer(' '.join(sent), return_offsets_mapping=True).offset_mapping\n",
    "        self.labels = self.get_labels(offset, label)\n",
    "        self.labels = [0] + self.labels[:max_Len] + [0]\n",
    "        self.labels = self.labels + [0]*(max_Len - len(self.labels) + 2)\n",
    "        \n",
    "        self.token_ids = [101] + self.token_ids[1:-1][:max_Len] + [102]\n",
    "        self.token_ids = self.token_ids + [0]*(max_Len - len(self.token_ids) + 2)\n",
    "        self.entity = get_entity(self.labels)\n",
    "\n",
    "    def get_labels(self, offset_mapping, label):\n",
    "        sent_len, count, index = 0, 0, 0\n",
    "        label_new = []\n",
    "        for l, r in offset_mapping:\n",
    "            if l != 0 or r != 0:\n",
    "                if count == sent_len:\n",
    "                    sent_len += len(self.sent[index])\n",
    "                    index += 1\n",
    "                count += r - l\n",
    "                label_new.append(label[index-1])\n",
    "                \n",
    "        return label_new\n",
    "\n",
    "class GetDatasetGenerator:\n",
    "    def __init__(self, path):\n",
    "        data = read_data(path)\n",
    "        self.features = [Feature(data[0][i], data[1][i]) for i in range(len(data[0]))]\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.features)\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        feature = self.features[index]\n",
    "        token_ids = feature.token_ids\n",
    "        labels = feature.labels\n",
    "        \n",
    "        return token_ids, feature.seq_length, labels\n",
    "\n",
    "def process_dataset(source, batch_size, shuffle):\n",
    "    dataset = ds.GeneratorDataset(source, [\"ids\", \"seq_length\", \"labels\"], shuffle=shuffle)\n",
    "    dataset = dataset.batch(batch_size=batch_size)\n",
    "    \n",
    "    return dataset\n",
    "\n",
    "def debug_dataset(dataset):\n",
    "    dataset = dataset.batch(batch_size=16)\n",
    "    for data in dataset.create_dict_iterator():\n",
    "        print(data[\"data\"].shape, data[\"label\"].shape)\n",
    "        break\n",
    "        \n",
    "def get_metric(P_ans, valid):\n",
    "    predict_score = 0 # 预测正确个数\n",
    "    predict_number = 0 # 预测结果个数\n",
    "    totol_number = 0 # 标签个数\n",
    "    for i in range(len(P_ans)):\n",
    "        predict_number += len(P_ans[i])\n",
    "        totol_number += len(valid.features[i].entity)\n",
    "        pred_true = [x for x in valid.features[i].entity if x in P_ans[i]]\n",
    "        predict_score += len(pred_true)\n",
    "    P = predict_score/predict_number if predict_number>0 else 0.\n",
    "    R = predict_score/totol_number if totol_number>0 else 0.\n",
    "    f1=(2*P*R)/(P+R) if (P+R)>0 else 0.\n",
    "    print(f'f1 = {f1}， P(准确率) = {P}, R(召回率) = {R}')\n",
    "    \n",
    "def get_optimizer(model):\n",
    "    param_optimizer = list(model.parameters_and_names())\n",
    "\n",
    "    no_decay = ['bias', 'layer_norm.bias', 'layer_norm.weight']\n",
    "    crf_p = [n for n, p in param_optimizer if str(n).find('crf') != -1]\n",
    "\n",
    "    optimizer_grouped_parameters = [\n",
    "            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) and  n not in crf_p], 'weight_decay': 0.8},\n",
    "            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) and n not in crf_p], 'weight_decay': 0.0},\n",
    "            {'params': [p for n, p in param_optimizer if n in crf_p], 'lr': 3e-3,'weight_decay': 0.8},\n",
    "            ]\n",
    "    optimizer = AdamW(optimizer_grouped_parameters, learning_rate=3e-5, eps=1e-8) # 学习率不宜过大，不然预测结果可能都是0\n",
    "\n",
    "    return optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Bert_LSTM_CRF(nn.Module):\n",
    "    def __init__(self, num_labels):\n",
    "        super().__init__()\n",
    "        self.num_labels = num_labels\n",
    "        self.bert_model = AutoModel.from_pretrained('bert-base-uncased')\n",
    "        hidden_size = self.bert_model.config.hidden_size\n",
    "        self.bilstm = nn.LSTM(hidden_size, hidden_size//2, batch_first=True, bidirectional=True)\n",
    "        self.crf_hidden_fc = nn.Dense(hidden_size, self.num_labels)\n",
    "        self.crf = CRF(self.num_labels, batch_first=True, reduction='mean')\n",
    "\n",
    "    def construct(self, ids, seq_length=None, labels=None):\n",
    "        attention_mask = (ids > mindspore.tensor(0))\n",
    "        output = self.bert_model(input_ids=ids, attention_mask=attention_mask)\n",
    "        lstm_feat, _ = self.bilstm(output[0])\n",
    "        emissions = self.crf_hidden_fc(lstm_feat)\n",
    "        loss_crf = self.crf(emissions, tags=labels, seq_length=seq_length)\n",
    "\n",
    "        return loss_crf"
   ]
  },
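  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inside `construct`, padded positions (token id 0) are excluded from attention. A minimal sketch of that mask logic on a dummy batch, assuming id 0 only ever appears as padding:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: the attention mask used in construct().\n",
    "dummy_ids = mindspore.Tensor([[101, 2023, 2003, 102, 0, 0]], mindspore.int64)\n",
    "print(dummy_ids > 0)  # True for real tokens, False for padding"
   ]
  },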
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_everything(42)\n",
    "max_Len = 113\n",
    "Entity = ['PER', 'LOC', 'ORG', 'MISC']\n",
    "labels_text_mp={k:v for k,v in enumerate(Entity)}\n",
    "LABEL_MAP = {'O': 0}\n",
    "for i, e in enumerate(Entity):\n",
    "    LABEL_MAP[f'B-{e}'] = 2 * (i+1) - 1\n",
    "    LABEL_MAP[f'I-{e}'] = 2 * (i+1)"
   ]
  },
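  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of the label scheme: `B-X` tags get odd ids and `I-X` tags even ids, which is exactly what `get_entity` assumes when it maps `label // 2` back to an entity type. The toy sequence below is illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(LABEL_MAP)\n",
    "# Toy decoded sequence: O, B-PER, I-PER, O, B-LOC\n",
    "print(get_entity([0, 1, 2, 0, 3]))  # expect [([1, 2], 'PER'), ([4], 'LOC')]"
   ]
  },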
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', use_fast=True)\n",
    "train = GetDatasetGenerator('conll2003/train.txt')\n",
    "test = GetDatasetGenerator('conll2003/test.txt')\n",
    "dev = GetDatasetGenerator('conll2003/valid.txt')"
   ]
  },
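  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before training, it helps to spot-check one processed example: the token ids should start with `[CLS]` (101) and end with `[SEP]` (102) before padding, and the aligned label ids should recover the gold entity spans."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check the first training feature.\n",
    "f = train.features[0]\n",
    "print(' '.join(f.sent))\n",
    "print(f.token_ids[:f.seq_length])\n",
    "print(f.labels[:f.seq_length])\n",
    "print(f.entity)"
   ]
  },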
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "epochs = 3\n",
    "batch_size = 16\n",
    "dataset_train = process_dataset(train, batch_size=batch_size, shuffle=False)\n",
    "model = Bert_LSTM_CRF(num_labels=len(Entity)*2+1)\n",
    "optimizer = get_optimizer(model)\n",
    "trainer = Trainer(network=model, train_dataset=dataset_train, optimizer=optimizer, epochs=epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.run(tgt_columns=\"labels\")"
   ]
  },
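  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, the fine-tuned weights can be saved with MindSpore's checkpoint API so the inference cells below can be re-run later without retraining. The file name is an arbitrary choice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: persist the fine-tuned weights (file name is arbitrary).\n",
    "mindspore.save_checkpoint(model, 'bert_lstm_crf.ckpt')"
   ]
  },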
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 预测：train\n",
    "dataset_train = process_dataset(train, batch_size=batch_size, shuffle=False)\n",
    "size = dataset_train.get_dataset_size()\n",
    "steps = size\n",
    "decodes=[]\n",
    "model.set_train(False)\n",
    "with tqdm(total=steps) as t:\n",
    "    for batch, (token_ids, seq_length, labels) in enumerate(dataset_train.create_tuple_iterator()):\n",
    "        score, history = model(token_ids, seq_length=seq_length)\n",
    "        best_tags = model.crf.post_decode(score, history, seq_length)\n",
    "        decode = [[y.asnumpy().item() for y in x] for x in best_tags]\n",
    "        decodes.extend(list(decode))\n",
    "        t.update(1)\n",
    "\n",
    "v_pred = [get_entity(x) for x in decodes]\n",
    "get_metric(v_pred, train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 预测：dev\n",
    "dataset_dev = process_dataset(dev, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "size = dataset_dev.get_dataset_size()\n",
    "steps = size\n",
    "decodes=[]\n",
    "model.set_train(False)\n",
    "with tqdm(total=steps) as t:\n",
    "    for batch, (token_ids, seq_length, labels) in enumerate(dataset_dev.create_tuple_iterator()):\n",
    "        score, history = model(token_ids, seq_length=seq_length)#.asnumpy()\n",
    "        best_tags = model.crf.post_decode(score, history, seq_length)\n",
    "        decode = [[y.asnumpy().item() for y in x] for x in best_tags]\n",
    "        decodes.extend(list(decode))\n",
    "        t.update(1)\n",
    "v_pred = [get_entity(x) for x in decodes]\n",
    "get_metric(v_pred, dev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 预测：test\n",
    "dataset_test = process_dataset(test, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "size = dataset_test.get_dataset_size()\n",
    "steps = size\n",
    "decodes_pred=[]\n",
    "model.set_train(False)\n",
    "with tqdm(total=steps) as t:\n",
    "    for batch, (token_ids, seq_length, labels) in enumerate(dataset_test.create_tuple_iterator()):\n",
    "        score, history = model(token_ids, seq_length=seq_length)\n",
    "        best_tags = model.crf.post_decode(score, history, seq_length)\n",
    "        decode = [[y.asnumpy().item() for y in x] for x in best_tags]\n",
    "        decodes_pred.extend(list(decode))\n",
    "        t.update(1)\n",
    "        \n",
    "\n",
    "pred = [get_entity(x) for x in decodes_pred]\n",
    "get_metric(pred, test)"
   ]
  },
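  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three inference cells above repeat the same decode loop. As a sketch (behavior unchanged, not used above), the loop can be consolidated into a single helper:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative refactor of the inference loops above.\n",
    "def evaluate(data_gen, batch_size=16):\n",
    "    dataset = process_dataset(data_gen, batch_size=batch_size, shuffle=False)\n",
    "    decodes = []\n",
    "    model.set_train(False)\n",
    "    with tqdm(total=dataset.get_dataset_size()) as t:\n",
    "        for token_ids, seq_length, labels in dataset.create_tuple_iterator():\n",
    "            score, history = model(token_ids, seq_length=seq_length)\n",
    "            best_tags = model.crf.post_decode(score, history, seq_length)\n",
    "            decodes.extend([[y.asnumpy().item() for y in x] for x in best_tags])\n",
    "            t.update(1)\n",
    "    get_metric([get_entity(x) for x in decodes], data_gen)\n",
    "\n",
    "# Example: evaluate(dev)"
   ]
  }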
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "vscode": {
   "interpreter": {
    "hash": "a62cb8bb4abcff3256df5ab1881dc7c3e7803473070698df3ff917df10adcce5"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
